import autograd
import autograd.numpy as np
import scipy.integrate
import torch
import argparse
import os
import sys
import random
from simulators.nn_models import MLPAutoencoder
from simulators.hnn import HNN, PixelHNN
from simulators.data import get_dataset
from simulators.utils import L2_loss, plot_results, plot_results_testing, process_stats
from simulators.noda import NODA, AE
from tqdm import tqdm
import matplotlib.pyplot as plt
import torchsummary
import pdb


def get_args(argv=None):
    """Build the command-line parser for this script and parse the arguments.

    Args:
        argv: optional list of argument strings to parse. When ``None``
            (the default) ``argparse`` falls back to ``sys.argv[1:]``, which is
            the behaviour of the original zero-argument call.

    Returns:
        The ``argparse.Namespace`` holding every option declared below, plus
        ``feature=True`` which is injected through ``set_defaults``.
    """
    parser = argparse.ArgumentParser(description=None)
    parser.add_argument('--input_dim', default=28**2, type=int, help='dimensionality of input tensor')
    parser.add_argument('--hidden_dim_mlp', default=2, type=int, help='hidden dimension of mlp')
    parser.add_argument('--hidden_dim_ae', default=2, type=int, help='hidden dimension of auto-encoder')
    parser.add_argument('--latent_dim', default=2, type=int, help='latent dimension of autoencoder')
    parser.add_argument('--learn_rate', default=1e-3, type=float, help='learning rate')
    parser.add_argument('--input_noise', default=0.0, type=float, help='std of noise added to inputs')
    parser.add_argument('--batch_size', default=200, type=int, help='batch size')
    parser.add_argument('--nonlinearity', default='tanh', type=str, help='neural net nonlinearity')
    parser.add_argument('--total_steps', default=3000, type=int, help='number of gradient steps')
    parser.add_argument('--print_every', default=200, type=int, help='number of gradient steps between prints')
    parser.add_argument('--verbose', dest='verbose', action='store_true', help='verbose?')
    # NOTE(review): the help text mentions "real"/"sim" but the default is 'pixels' -- confirm intended values.
    parser.add_argument('--name', default='pixels', type=str, help='either "real" or "sim" data')
    parser.add_argument('--seed', default=0, type=int, help='random seed')
    parser.add_argument('--retrain', action='store_true', default=False,
                        help='whether or not to retrain the models')
    parser.add_argument('--save_dir', default='results/simulators/pixel_pendulum', type=str,
                        help='where to save the data')
    parser.set_defaults(feature=True)
    # Passing argv through lets callers (and tests) parse an explicit list
    # instead of being tied to the process-wide sys.argv.
    return parser.parse_args(argv)


def pixelhnn_loss(x, x_next, model, return_scalar=True, input_noise=None):
    """Combined autoencoder / dynamics / canonical-coordinate loss for a pixel HNN.

    Args:
        x: batch of inputs in pixel space (reduced over dim 1, so at least 2-D).
        x_next: batch of the corresponding next-step inputs, same layout as ``x``.
        model: object exposing ``encode``, ``decode`` and ``time_derivative``.
            The encoder output must have exactly two entries along dim 1,
            because the latent is split into two size-1 chunks below.
        return_scalar: when true return the batch mean, otherwise the
            per-sample loss vector.
        input_noise: std of the Gaussian noise added to the latent before the
            vector field is evaluated. When ``None`` the value is read from the
            module-level ``args`` namespace, which only exists when this file
            is executed as a script.

    Returns:
        A 0-d tensor (``return_scalar=True``) or a 1-D tensor with one loss
        value per sample.
    """
    if input_noise is None:
        # Backward-compatible fallback to the script-level global.
        input_noise = args.input_noise

    # encode pixel space -> latent dimension
    z = model.encode(x)
    z_next = model.encode(x_next)

    # autoencoder loss
    x_hat = model.decode(z)
    ae_loss = ((x - x_hat)**2).mean(1)

    # hnn vector field loss
    # randn_like draws the noise on z's device/dtype, so this also works when
    # the model lives on the GPU (torch.randn(*shape) always allocates on CPU).
    noise = input_noise * torch.randn_like(z)
    z_hat_next = z + model.time_derivative(z + noise)  # replace with rk4
    hnn_loss = ((z_next - z_hat_next)**2).mean(1)

    # canonical coordinate loss
    # -> makes latent space look like (x, v) coordinates
    w, dw = z.split(1, 1)
    w_next, _ = z_next.split(1, 1)
    cc_loss = ((dw - (w_next - w)) ** 2).mean(1)

    # weighted sum of the three terms; the dynamics term is down-weighted
    loss = ae_loss + cc_loss + 1e-1 * hnn_loss
    if return_scalar:
        return loss.mean()
    return loss


def train_HNN(args):
    """Train a PixelHNN on the pendulum pixel dataset.

    Args:
        args: parsed command-line namespace (see ``get_args``); the dimensions,
            learning rate, noise level, batch size, step count, logging options,
            seed and ``save_dir`` are read from it.

    Returns:
        ``(model, stats)`` where ``stats`` maps ``'train_loss'`` and
        ``'test_loss'`` to a list with one ``[mean, std]`` entry per step,
        computed from ``model.get_l2_loss`` on the full train / test tensors.
    """
    # model: PixelHNN wrapped around a ReLU MLP autoencoder
    encoder_decoder = MLPAutoencoder(args.input_dim, args.hidden_dim_ae, args.latent_dim, nonlinearity='relu')
    model = PixelHNN(args.latent_dim, args.hidden_dim_mlp, autoencoder=encoder_decoder,
                     nonlinearity=args.nonlinearity, baseline=False)
    n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("HNN has {} paramerters in total".format(n_trainable))
    optim = torch.optim.Adam(model.parameters(), args.learn_rate, weight_decay=1e-5)

    # dataset -> float32 tensors
    data = get_dataset('pendulum', args.save_dir, verbose=True, seed=args.seed)

    def as_float_tensor(key):
        return torch.tensor(data[key], dtype=torch.float32)

    x = as_float_tensor('pixels')
    test_x = as_float_tensor('test_pixels')
    next_x = as_float_tensor('next_pixels')
    test_next_x = as_float_tensor('test_next_pixels')

    stats = {'train_loss': [], 'test_loss': []}
    with tqdm(total=args.total_steps) as progress:
        for step in range(args.total_steps):
            # sample a mini-batch and perturb both frames with Gaussian noise
            batch_ixs = torch.randperm(x.shape[0])[:args.batch_size]
            batch_x = x[batch_ixs]
            batch_next_x = next_x[batch_ixs]
            noisy_x = batch_x + args.input_noise * torch.randn(*batch_x.shape).to(x.device)
            noisy_next_x = batch_next_x + args.input_noise * torch.randn(*batch_next_x.shape).to(next_x.device)

            # one optimizer step on the combined pixel-HNN loss
            loss = pixelhnn_loss(noisy_x, noisy_next_x, model)
            loss.backward()
            optim.step()
            optim.zero_grad()

            # per-step statistics on the complete train and test sets
            full_train = model.get_l2_loss(x, next_x).cpu().numpy()
            full_test = model.get_l2_loss(test_x, test_next_x).cpu().numpy()
            train_mean = full_train.mean()
            test_mean = full_test.mean()
            stats['train_loss'].append([train_mean, full_train.std()])
            stats['test_loss'].append([test_mean, full_test.std()])
            progress.set_postfix(train_loss='{:.9f}'.format(train_mean),
                                 test_loss='{:.9f}'.format(test_mean))

            if args.verbose and step % args.print_every == 0:
                # validation loss on a random test mini-batch
                val_ixs = torch.randperm(test_x.shape[0])[:args.batch_size]
                val_loss = pixelhnn_loss(test_x[val_ixs], test_next_x[val_ixs], model)
                print("step {}, train_loss {:.4e}, test_loss {:.4e}".format(step, loss.item(), val_loss.item()))
            progress.update()

    # final summary: mean +/- standard error over all samples
    train_dist = model.get_l2_loss(x, next_x)
    test_dist = model.get_l2_loss(test_x, test_next_x)
    train_sem = train_dist.std().item() / np.sqrt(train_dist.shape[0])
    test_sem = test_dist.std().item() / np.sqrt(test_dist.shape[0])
    print('Final train loss {:.4e} +/- {:.4e}\nFinal test loss {:.4e} +/- {:.4e}'.
          format(train_dist.mean().item(), train_sem,
                 test_dist.mean().item(), test_sem))
    return model, stats


def _train_latent_simulator(model_cls, label, args):
    """Shared training loop for the simulators that expose ``forward_train``.

    ``train_NODA`` and ``train_AE`` were identical apart from the model class
    and the name printed in the parameter-count message, so both delegate here.

    Args:
        model_cls: class to instantiate; called as
            ``model_cls(input_dim, hidden_dim_mlp, hidden_dim_ae, latent_dim,
            learn_rate, nonlinearity=...)``.
        label: name used in the parameter-count message.
        args: parsed command-line namespace (see ``get_args``).

    Returns:
        ``(model, stats)`` where ``stats`` maps ``'train_loss'`` and
        ``'test_loss'`` to a list with one ``[mean, std]`` entry per step.
    """
    use_cuda = torch.cuda.is_available()

    # init model (the learning rate is handed to the model; no optimizer is
    # driven from this loop -- presumably forward_train steps it internally)
    model = model_cls(args.input_dim, args.hidden_dim_mlp, args.hidden_dim_ae, args.latent_dim, args.learn_rate,
                      nonlinearity=args.nonlinearity)
    if use_cuda:
        model = model.cuda()
    n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("{} has {} paramerters in total".format(label, n_trainable))

    # get dataset
    data = get_dataset('pendulum', args.save_dir, verbose=True, seed=args.seed)

    x = torch.tensor(data['pixels'], dtype=torch.float32)
    test_x = torch.tensor(data['test_pixels'], dtype=torch.float32)
    next_x = torch.tensor(data['next_pixels'], dtype=torch.float32)
    test_next_x = torch.tensor(data['test_next_pixels'], dtype=torch.float32)

    if use_cuda:
        x = x.cuda()
        test_x = test_x.cuda()
        next_x = next_x.cuda()
        test_next_x = test_next_x.cuda()

    stats = {'train_loss': [], 'test_loss': []}
    with tqdm(total=args.total_steps) as t:
        for step in range(args.total_steps):
            # train step on a noisy random mini-batch; the noise is drawn on
            # the CPU and then moved, which keeps the random stream unchanged
            ixs = torch.randperm(x.shape[0])[:args.batch_size]
            batch_x = x[ixs]
            batch_next_x = next_x[ixs]
            noisy_x = batch_x + args.input_noise * torch.randn(*batch_x.shape).to(x.device)
            noisy_next_x = batch_next_x + args.input_noise * torch.randn(*batch_next_x.shape).to(next_x.device)
            loss = model.forward_train(noisy_x, noisy_next_x)

            # per-step statistics on the complete train and test sets
            train_loss = model.forward_train(x, next_x, False, False).cpu().numpy()
            test_loss = model.forward_train(test_x, test_next_x, False, False).cpu().numpy()
            stats['train_loss'].append([train_loss.mean(), train_loss.std()])
            stats['test_loss'].append([test_loss.mean(), test_loss.std()])
            t.set_postfix(train_loss='{:.9f}'.format(train_loss.mean()),
                          test_loss='{:.9f}'.format(test_loss.mean()))
            if args.verbose and step % args.print_every == 0:
                # run validation on a random test mini-batch
                test_ixs = torch.randperm(test_x.shape[0])[:args.batch_size]
                val_loss = model.forward_train(test_x[test_ixs], test_next_x[test_ixs], False)
                print("step {}, train_loss {:.4e}, test_loss {:.4e}"
                      .format(step, loss.item(), val_loss.item()))
            t.update()

    # final summary: mean +/- standard error over all samples
    train_dist = model.forward_train(x, next_x, False, False)
    test_dist = model.forward_train(test_x, test_next_x, False, False)
    print('Final train loss {:.4e} +/- {:.4e}\nFinal test loss {:.4e} +/- {:.4e}'
          .format(train_dist.mean().item(), train_dist.std().item() / np.sqrt(train_dist.shape[0]),
                  test_dist.mean().item(), test_dist.std().item() / np.sqrt(test_dist.shape[0])))
    return model, stats


def train_NODA(args):
    """Train a ``NODA`` model on the pendulum pixel dataset.

    Args:
        args: parsed command-line namespace (see ``get_args``).

    Returns:
        ``(model, stats)`` as produced by ``_train_latent_simulator``.
    """
    return _train_latent_simulator(NODA, 'NODA', args)


def train_AE(args):
    """Train an ``AE`` model on the pendulum pixel dataset.

    Args:
        args: parsed command-line namespace (see ``get_args``).

    Returns:
        ``(model, stats)`` as produced by ``_train_latent_simulator``.
    """
    return _train_latent_simulator(AE, 'AE', args)


def train(args):
    """Train the HNN, AE and NODA models in turn and save their statistics.

    Args:
        args: parsed command-line namespace (see ``get_args``); ``save_dir``,
            ``hidden_dim_ae``, ``hidden_dim_mlp`` and ``total_steps`` determine
            the name of the ``.npz`` file that is written.

    Returns:
        ``(stats_list, labels_list)``: the processed statistics in the order
        HNN, AE, NODA, and the matching list of labels.
    """
    trainers = (train_HNN, train_AE, train_NODA)
    labels_list = ['HNN', 'AE', 'NODA']

    # run every training first, then post-process, keeping the original order
    raw_stats = [trainer(args)[1] for trainer in trainers]
    stats_list = [process_stats(raw) for raw in raw_stats]

    run_tag = '_'.join(str(value) for value in (args.hidden_dim_ae, args.hidden_dim_mlp, args.total_steps))
    out_file = args.save_dir + '/simulation_results_pixel_pendulum_' + run_tag + '.npz'
    np.savez(out_file, stats_list=stats_list, labels_list=labels_list)
    return stats_list, labels_list


if __name__ == "__main__":
    # Script entry point: seed everything, then either load cached results for
    # this configuration or (re)train the three models, and finally plot.
    os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
    solve_ivp = scipy.integrate.solve_ivp
    THIS_DIR = os.path.dirname(os.path.abspath(__file__))
    PARENT_DIR = os.path.dirname(THIS_DIR)
    # NOTE(review): this runs after the top-of-file `simulators` imports, so it
    # cannot influence how those were resolved -- confirm it is still needed.
    sys.path.append(PARENT_DIR)
    args = get_args()

    # reproducibility
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    # NOTE(review): setting PYTHONHASHSEED here presumably only affects child
    # processes, not the hashing of the already-running interpreter.
    os.environ['PYTHONHASHSEED'] = str(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    os.makedirs(args.save_dir, exist_ok=True)

    plt.rcParams['font.sans-serif'] = ['Times New Roman']
    plt.rcParams.update({'figure.autolayout': True})
    plt.rc('font', size=23)

    results_path = (args.save_dir + '/simulation_results_pixel_pendulum_' + str(args.hidden_dim_ae) + '_' +
                    str(args.hidden_dim_mlp) + '_' + str(args.total_steps) + '.npz')

    cached = None
    if not args.retrain:
        try:
            # NOTE: allow_pickle=True executes pickled objects; only load
            # result files produced by this script.
            with np.load(results_path, allow_pickle=True) as results:
                # read each array once (every key access re-reads the archive)
                cached = (list(results['stats_list']), list(results['labels_list']))
        except Exception as err:
            # Missing or unreadable cache: fall back to training, but say why.
            # Unlike a bare `except:`, Ctrl-C / SystemExit are not swallowed.
            print('Could not load cached results from {} ({}: {}); retraining.'
                  .format(results_path, type(err).__name__, err))

    if cached is None:
        stats_list, labels_list = train(args)
    else:
        stats_list, labels_list = cached

    plot_results(args, stats_list, labels_list, title='Pixel Pendulum')
    plot_results_testing(args, stats_list, labels_list, title='Pixel Pendulum')
